from symbol.builder import add_anchor_to_arg
from symbol.builder import ResNetV1bFPN as Backbone
from models.FPN.builder import FPNNeck as Neck
from models.FPN.builder import FPNRoiAlign as RoiExtractor
from models.FPN.builder import FPNBbox2fcHead as BboxHead
from mxnext.complicate import normalizer_factory

from models.maskrcnn.builder import MaskFasterRcnn as Detector
from models.maskrcnn.builder import MaskFPNRpnHead as RpnHead
from models.maskrcnn.builder import MaskFasterRcnn4ConvHead as MaskHead
from models.maskrcnn.builder import BboxPostProcessor
from models.maskrcnn.process_output import process_output


def get_config(is_train):
    """Build the Mask R-CNN (ResNet-v1b + FPN) experiment configuration.

    Args:
        is_train: when truthy, build the training symbol, the training data
            pipeline (mask polygons, and CPU-side RPN targets unless
            ``RpnParam.nnvm_rpn_target`` is set) and training-time sizes;
            otherwise build the test symbol and the inference data pipeline.

    Returns:
        A tuple ``(General, KvstoreParam, RpnParam, RoiParam, BboxParam,
        DatasetParam, ModelParam, OptimizeParam, TestParam, transform,
        data_name, label_name, metric_list)``.
    """
    class General:
        # number of iterations between two prints of the metrics to stdout
        log_frequency = 10
        # the directory name for the experiment, the default is the name of config
        name = __name__.rsplit("/")[-1].rsplit(".")[-1]
        # batch size per GPU
        batch_image = 2 if is_train else 1
        # use FP16 for weight and activation
        # recommend to toggle when you are training on Volta or later GPUs
        fp16 = False
        # number of threads used for the data loader
        # this term affects both the CPU utilization and the MEM usage
        # lower this if you are training on Desktop
        loader_worker = 8
        # switch the built in profile to find the bottleneck of network
        profile = False


    class KvstoreParam:
        # the type of communicator used to sync model parameters
        kvstore     = "nccl"  # "local", "aggregated"
        batch_image = General.batch_image
        # GPUs to use
        gpus        = [0, 1, 2, 3, 4, 5, 6, 7]
        fp16        = General.fp16


    class NormalizeParam:
        # the type of normalizer used for network
        # see also ModelParam.pretrain.fixed_param for the freeze of gamma/beta
        # Keep exactly ONE active assignment: each line rebinds the same name,
        # so listing several executed assignments silently keeps only the last.
        # Frozen BN is consistent with freezing "gamma"/"beta" in
        # ModelParam.pretrain.fixed_param. Alternatives:
        #   normalizer_factory(type="localbn")  # use bn stats in one GPU
        #   normalizer_factory(type="syncbn", ndev=len(KvstoreParam.gpus))  # use bn stats across GPUs
        #   normalizer_factory(type="gn")  # use GroupNorm (revisit fixed_param if you switch)
        normalizer = normalizer_factory(type="fixbn")  # freeze bn stats


    class BackboneParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer
        # some backbone component accept additional configs, like the depth for ResNet
        depth = 50


    class NeckParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer


    class RpnParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer
        batch_image = General.batch_image
        # use ONNX-compatible proposal operator instead of the one written in C++/CUDA
        nnvm_proposal = True
        # use in-network rpn target operator instead of the label generated by data loader
        # if your network is quite fast, the CPU might not feed the labels fast enough
        # else you can offload the rpn target generation to CPU to save GPU resources
        nnvm_rpn_target = False

        # anchor grid generated are used in the rpn target assign and proposal decoding
        class anchor_generate:
            scale = (8,)
            ratio = (0.5, 1.0, 2.0)
            stride = (4, 8, 16, 32, 64)
            # number of anchors per image
            image_anchor = 256
            # to avoid generating the same anchor grid more than once
            # we cache an anchor grid in the arg_params
            # max_side specifies the max side of the (resized and padded) input image;
            # keep it >= PadParam.long (1333 here) and increase it if necessary
            max_side = 1400

        # valid when use nnvm_rpn_target, controls the rpn target assign
        class anchor_assign:
            # number of pixels the anchor box could extend out of the image border
            allowed_border = 0
            # iou lower bound with groundtruth box for foreground anchor
            pos_thr = 0.7
            # iou upper bound with groundtruth box for background anchor
            neg_thr = 0.3
            # every groundtruth box will match the anchors overlaps most with it by default
            # increase the threshold to avoid matching low quality anchors
            min_pos_thr = 0.0
            # number of anchors per image
            image_anchor = 256
            # fraction of foreground anchors per image
            pos_fraction = 0.5

        # rpn head structure
        class head:
            # number of channels for the 3x3 conv in rpn head
            conv_channel = 256
            # mean and std for rpn regression target
            mean = (0, 0, 0, 0)
            std = (1, 1, 1, 1)

        # the proposal generation for RCNN
        class proposal:
            # number of top-scored proposals to take before NMS
            pre_nms_top_n = 2000 if is_train else 1000
            # number of top-scored proposals to take after NMS
            post_nms_top_n = 2000 if is_train else 1000
            # proposal NMS threshold
            nms_thr = 0.7
            # min proposal box to keep, 0 means keep all
            min_bbox_side = 0

        # the proposal sampling for RCNN during training
        class subsample_proposal:
            # add gt to proposals
            proposal_wo_gt = False
            # number of proposals sampled per image during training
            image_roi = 512
            # the maximum fraction of foreground proposals
            fg_fraction = 0.25
            # iou lower bound with gt bbox for foreground proposals
            fg_thr = 0.5
            # iou upper bound with gt bbox for background proposals
            bg_thr_hi = 0.5
            # iou lower bound with gt bbox for background proposals
            # set to non-zero value could remove some trivial background proposals
            bg_thr_lo = 0.0

        # the target encoding for RCNN bbox head
        class bbox_target:
            # 1(background) + num_class
            # could be num_class if using sigmoid activation instead of softmax one
            num_reg_class = 1 + 80
            # share the regressor for all classes
            class_agnostic = False
            # the mean, std, and weight for bbox head regression target
            weight = (1.0, 1.0, 1.0, 1.0)
            mean = (0.0, 0.0, 0.0, 0.0)
            std = (0.1, 0.1, 0.2, 0.2)


    class BboxParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer
        # num_class may be different from RpnParam.bbox_target.num_reg_class
        # if the class_agnostic regressor is adopted
        num_class   = 1 + 80
        image_roi   = RpnParam.subsample_proposal.image_roi
        batch_image = General.batch_image

        class regress_target:
            class_agnostic = RpnParam.bbox_target.class_agnostic
            mean = RpnParam.bbox_target.mean
            std = RpnParam.bbox_target.std


    class MaskParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16        = General.fp16
        normalizer  = NormalizeParam.normalizer
        # output resolution of mask head
        resolution  = 28
        # number of channels for 3x3 convs in mask head
        dim_reduced = 256
        # mask head only trains on foreground proposals
        # so we discard all the background proposals to save computation
        num_fg_roi  = int(RpnParam.subsample_proposal.image_roi * RpnParam.subsample_proposal.fg_fraction)


    class RoiParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer
        # Each RoI is pooled into an out_size x out_size fixed-length representation
        out_size = 7
        # the total stride of the feature map to pool from
        stride = (4, 8, 16, 32)
        # FPN specific configs
        # objects of size in [224^2, 448^2) will be assigned to P4
        roi_canonical_scale = 224
        roi_canonical_level = 4


    class MaskRoiParam:
        # you can control the FP16 option and normalizer for each individual component
        fp16 = General.fp16
        normalizer = NormalizeParam.normalizer
        # Each RoI is pooled into an out_size x out_size fixed-length representation
        out_size = 14
        # the total stride of the feature map to pool from
        stride = (4, 8, 16, 32)
        # FPN specific configs
        # objects of size in [224^2, 448^2) will be assigned to P4
        roi_canonical_scale = 224
        roi_canonical_level = 4


    class DatasetParam:
        # specify the roidbs to read for training/validation
        if is_train:
            # == coco_train2017
            image_set = ("coco_train2014", "coco_valminusminival2014")
        else:
            # == coco_val2017
            image_set = ("coco_minival2014", )


    class OptimizeParam:
        class optimizer:
            type = "sgd"
            # learning rate will automatically adapt to different batch size
            # the base learning rate is 0.02 for 16 images
            lr = 0.01 / 8 * len(KvstoreParam.gpus) * KvstoreParam.batch_image
            momentum = 0.9
            wd = 0.0001
            clip_gradient = None

        class schedule:
            # correspond to the 1x, 2x, ... training schedule
            mult = 2
            begin_epoch = 0
            end_epoch = 6 * mult
            lr_mode = "step"  # or "cosine"
            # lr step factor
            lr_factor = 0.1
            # lr step iterations
            if mult <= 1:
                lr_iter = [60000 * mult * 16 // (len(KvstoreParam.gpus) * KvstoreParam.batch_image),
                           80000 * mult * 16 // (len(KvstoreParam.gpus) * KvstoreParam.batch_image)]
            else:
                # follow the practice in arXiv:1811.08883
                # reduce the lr in the last 60k and 20k iterations
                lr_iter = [-60000 * 16 // (len(KvstoreParam.gpus) * KvstoreParam.batch_image),
                           -20000 * 16 // (len(KvstoreParam.gpus) * KvstoreParam.batch_image)]

        # follow the practice in arXiv:1706.02677
        class warmup:
            type = "gradual"
            lr = 0.01 / 8 * len(KvstoreParam.gpus) * KvstoreParam.batch_image / 3
            iter = 500


    class TestParam:
        # detection below min_det_score will be removed in the evaluation
        min_det_score = 0.05
        # only the top max_det_per_image detections will be evaluated
        max_det_per_image = 100

        # callback, useful in multi-scale testing
        process_roidb = lambda x: x
        # callback, useful in scale-aware post-processing
        process_output = lambda x, y: process_output(x, y)

        # the model name and epoch used during test
        # by default the last checkpoint is employed
        # user can override this with --epoch N when invoking script
        class model:
            prefix = "experiments/{}/checkpoint".format(General.name)
            epoch = OptimizeParam.schedule.end_epoch

        class nms:
            type = "nms"  # or "softnms"
            thr = 0.5

        # we make use of the coco test toolchain
        # if no coco format annotation file is specified
        # test script will generate one on the fly from roidb
        class coco:
            annotation = "data/coco/annotations/instances_minival2014.json"

    # compose the components for a detector
    backbone = Backbone(BackboneParam)
    neck = Neck(NeckParam)
    rpn_head = RpnHead(RpnParam, MaskParam)
    roi_extractor = RoiExtractor(RoiParam)
    mask_roi_extractor = RoiExtractor(MaskRoiParam)
    bbox_head = BboxHead(BboxParam)
    mask_head = MaskHead(BboxParam, MaskParam, MaskRoiParam)
    bbox_post_processor = BboxPostProcessor(TestParam)
    detector = Detector()
    if is_train:
        train_sym = detector.get_train_symbol(backbone, neck, rpn_head, roi_extractor, mask_roi_extractor, bbox_head, mask_head)
        test_sym = None
    else:
        train_sym = None
        test_sym = detector.get_test_symbol(backbone, neck, rpn_head, roi_extractor, mask_roi_extractor, bbox_head, mask_head, bbox_post_processor)


    class ModelParam:
        train_symbol = train_sym
        test_symbol = test_sym

        # training model from scratch
        from_scratch = False
        # use random seed when initializing
        random = True
        # sublinear memory checkpointing
        memonger = False
        # checkpointing up to a layer
        # recompute early stage of a network is cheaper
        # NOTE(review): this unit name looks copied from a deeper ResNet config;
        # confirm it exists for BackboneParam.depth before enabling memonger
        memonger_until = "stage3_unit21_plus"

        class pretrain:
            # the model name and epoch used for initialization
            prefix = "pretrain_model/resnet%s_v1b" % BackboneParam.depth
            epoch = 0
            # any params partially match the fixed_param will be fixed
            # fixed params will not be updated
            fixed_param = ["conv0", "stage1", "gamma", "beta"]
            # any params partially match the excluded_param will not be fixed
            excluded_param = ["mask_fcn"]

        # callback, useful for adding cached anchor or complex initialization
        def process_weight(sym, arg, aux):
            for stride in RpnParam.anchor_generate.stride:
                add_anchor_to_arg(
                    sym, arg, aux, RpnParam.anchor_generate.max_side,
                    stride, RpnParam.anchor_generate.scale,
                    RpnParam.anchor_generate.ratio)


    # data processing
    class NormParam:
        # mean/std for input image
        mean = tuple(i * 255 for i in (0.485, 0.456, 0.406)) # RGB order
        std = tuple(i * 255 for i in (0.229, 0.224, 0.225))


    # data processing
    class ResizeParam:
        # the input is resized to a short side not exceeding short
        # and a long side not exceeding long
        short = 800
        long = 1333


    # SimpleDet is written in MXNet symbolic API which features the fastest
    # execution while requires static input shape
    # All the inputs are padded to the maximum shape item on the dataset
    class PadParam:
        # the resized input is padded to short x long with 0 in bottom-right corner
        short = 800
        long = 1333

        max_num_gt = 100
        max_len_gt_poly = 2500


    # this controls the rpn target generation offloaded to CPU data loader
    # refer to RpnParam.anchor_generate for more info
    # values are read from RpnParam / PadParam so loader and network cannot drift apart
    class AnchorTarget2DParam:
        def __init__(self):
            self.generate = self._generate()

        class _generate:
            def __init__(self):
                self.stride = RpnParam.anchor_generate.stride
                # the shorts and longs have to be pre-computed since the
                # loader knows nothing of the network
                # each level halves the previous one as ceil(side / 2), which equals
                # ceil(padded_side / stride); for an 800 x 1333 padded input this gives
                # short = (200, 100, 50, 25, 13) and long = (334, 167, 84, 42, 21)
                self.short = tuple(-(-PadParam.short // s) for s in self.stride)
                self.long = tuple(-(-PadParam.long // s) for s in self.stride)
            scales = RpnParam.anchor_generate.scale
            aspects = RpnParam.anchor_generate.ratio

        class assign:
            allowed_border = RpnParam.anchor_assign.allowed_border
            pos_thr = RpnParam.anchor_assign.pos_thr
            neg_thr = RpnParam.anchor_assign.neg_thr
            min_pos_thr = RpnParam.anchor_assign.min_pos_thr

        class sample:
            image_anchor = RpnParam.anchor_assign.image_anchor
            pos_fraction = RpnParam.anchor_assign.pos_fraction


    # align blobs name between loader and network
    class RenameParam:
        mapping = dict(image="data")


    from core.detection_input import ReadRoiRecord, Resize2DImageBbox, \
        ConvertImageFromHwcToChw, Flip2DImageBbox, Pad2DImageBbox, \
        RenameRecord, Norm2DImage

    from models.maskrcnn.input import PreprocessGtPoly, EncodeGtPoly, \
        Resize2DImageBboxMask, Flip2DImageBboxMask, Pad2DImageBboxMask

    from models.FPN.input import PyramidAnchorTarget2D

    # modular data augmentation design
    if is_train:
        transform = [
            ReadRoiRecord(None),
            Norm2DImage(NormParam),
            PreprocessGtPoly(),
            Resize2DImageBboxMask(ResizeParam),
            Flip2DImageBboxMask(),
            EncodeGtPoly(PadParam),
            Pad2DImageBboxMask(PadParam),
            ConvertImageFromHwcToChw(),
            RenameRecord(RenameParam.mapping)
        ]
        data_name = ["data"]
        label_name = ["im_info", "gt_bbox", "gt_poly"]
        if not RpnParam.nnvm_rpn_target:
            transform.append(PyramidAnchorTarget2D(AnchorTarget2DParam()))
            label_name += ["rpn_cls_label", "rpn_reg_target", "rpn_reg_weight"]
    else:
        transform = [
            ReadRoiRecord(None),
            Norm2DImage(NormParam),
            Resize2DImageBbox(ResizeParam),
            ConvertImageFromHwcToChw(),
            RenameRecord(RenameParam.mapping)
        ]
        data_name = ["data", "im_info", "im_id", "rec_id"]
        label_name = []

    import core.detection_metric as metric
    from models.maskrcnn.metric import SigmoidCELossMetric
    from mxboard import SummaryWriter

    # summary writer logs metric to tensorboard for a better track of training
    sw = SummaryWriter(logdir="./tflogs", flush_secs=5)

    rpn_acc_metric = metric.AccWithIgnore(
        name="RpnAcc",
        output_names=["rpn_cls_loss_output", "rpn_cls_label_blockgrad_output"],
        label_names=[],
        summary=sw
    )
    rpn_l1_metric = metric.L1(
        name="RpnL1",
        output_names=["rpn_reg_loss_output", "rpn_cls_label_blockgrad_output"],
        label_names=[],
        summary=sw
    )
    box_acc_metric = metric.AccWithIgnore(
        name="RcnnAcc",
        output_names=["bbox_cls_loss_output", "bbox_label_blockgrad_output"],
        label_names=[],
        summary=sw
    )
    box_l1_metric = metric.L1(
        name="RcnnL1",
        output_names=["bbox_reg_loss_output", "bbox_label_blockgrad_output"],
        label_names=[],
        summary=sw
    )
    mask_cls_metric = SigmoidCELossMetric(
        name="MaskCE",
        output_names=["mask_loss_output"],
        label_names=[],
        summary=sw
    )

    metric_list = [rpn_acc_metric, rpn_l1_metric, box_acc_metric, box_l1_metric, mask_cls_metric]

    return General, KvstoreParam, RpnParam, RoiParam, BboxParam, DatasetParam, \
           ModelParam, OptimizeParam, TestParam, \
           transform, data_name, label_name, metric_list
